package org.example.reactive.demo3.handler;

import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.example.reactive.demo3.entity.Message;
import org.springframework.stereotype.Component;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

/**
 * Reactive WebSocket handler that relays text messages between connected users.
 *
 * <p>A client identifies itself through the {@code id} query parameter of the handshake URI.
 * Every inbound text frame is parsed as a JSON {@link Message}; its text is forwarded to the
 * session registered under the message's target id. If no such session is registered, the
 * sender receives a "target user is offline" notice instead.
 *
 * @author Gong.Yang
 */
@Component
@Slf4j
public class ChatSocketHandler implements WebSocketHandler {
    /** Currently registered sessions, keyed by the {@code id} handshake query parameter. */
    private static final Map<String, WebSocketSession> userMap = new ConcurrentHashMap<>();

    /**
     * Registers the session under its user id and relays every inbound message.
     *
     * @param session the newly established WebSocket session
     * @return a {@code Mono} that completes when the inbound stream of the session terminates
     */
    @Override
    public Mono<Void> handle(WebSocketSession session) {
        String query = session.getHandshakeInfo().getUri().getQuery();
        // A missing "id" parameter falls back to the empty string, as before.
        String userId = getQueryMap(query).getOrDefault("id", "");
        userMap.put(userId, session);
        return session.receive()
                .flatMap(webSocketMessage -> relay(session, webSocketMessage.getPayloadAsText()))
                .then()
                // Remove the mapping once the connection ends, but only if it still points at
                // THIS session: if the same user has reconnected in the meantime, the newer
                // session must stay registered.
                .doFinally(signal -> userMap.remove(userId, session));
    }

    /**
     * Parses one inbound payload and forwards its text to the target user.
     *
     * @param sender  session the payload was received on
     * @param payload raw text payload, expected to be the JSON form of a {@link Message}
     * @return a {@code Mono} that completes once the message (or the offline notice) was sent,
     *         or immediately when the payload is unusable
     */
    private Mono<Void> relay(WebSocketSession sender, String payload) {
        Message message;
        try {
            message = JSONUtil.toBean(payload, Message.class);
        } catch (RuntimeException e) {
            // One malformed frame must not tear down the whole connection.
            log.warn("Ignoring unparsable message from session {}", sender.getId(), e);
            return Mono.empty();
        }
        if (message == null || message.getMessageText() == null) {
            log.warn("Ignoring message without text from session {}", sender.getId());
            return Mono.empty();
        }

        String targetId = message.getTargetId();
        // ConcurrentHashMap rejects null keys with a NullPointerException, so guard the lookup.
        // A single get() also avoids the containsKey/get race of a two-step lookup.
        WebSocketSession targetSession = targetId == null ? null : userMap.get(targetId);
        if (targetSession == null) {
            return sendOfflineNotice(sender);
        }

        WebSocketMessage textMessage = targetSession.textMessage(message.getMessageText());
        return targetSession.send(Mono.just(textMessage))
                .onErrorResume(e -> {
                    // A failing target must not break the sender's own connection.
                    log.warn("Failed to deliver message to user {}", targetId, e);
                    return sendOfflineNotice(sender);
                });
    }

    /** Tells the sender that the target user is not online. */
    private Mono<Void> sendOfflineNotice(WebSocketSession sender) {
        return sender.send(Mono.just(sender.textMessage("目标用户不在线")));
    }

    /**
     * Splits a URI query string of the form {@code k1=v1&k2=v2} into a map.
     *
     * <p>A parameter without {@code =} is stored with an empty value; only the first {@code =}
     * of a parameter separates key and value. Empty segments (e.g. from {@code a=1&&b=2}) are
     * skipped. When a key occurs more than once, the last occurrence wins.
     *
     * @param queryStr the query string, may be {@code null} or empty
     * @return a mutable map of parameter names to values, never {@code null}
     */
    private Map<String, String> getQueryMap(String queryStr) {
        Map<String, String> queryMap = new HashMap<>();
        if (StrUtil.isEmpty(queryStr)) {
            return queryMap;
        }
        for (String param : queryStr.split("&")) {
            if (param.isEmpty()) {
                continue;
            }
            String[] kv = param.split("=", 2);
            queryMap.put(kv[0], kv.length == 2 ? kv[1] : "");
        }
        return queryMap;
    }
}
